import numpy as np
from tqdm.auto import tqdm
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter
import pandas as pd
from tqdm import tqdm
import json
# Lazily-populated state used by uncached_tokens_to_latency(); the three values below
# stay None until the latency measurements have been loaded and fitted.
table_cached, slope_latency_model, intercept_latency_model = None, None, None
# When True, the next call to uncached_tokens_to_latency() plots the data and linear fit
# to latency_fit.png and then resets this flag to False.
plot_fit = False
def _series_xy(no_cache_series):
    """Return (x, y) float arrays from a Series indexed by uncached_tokens with duration values."""
    x_values = np.asarray(no_cache_series.index, dtype=float)
    y_values = np.asarray(no_cache_series.to_numpy(), dtype=float)
    return x_values, y_values


def _fit_latency_model(raw_table):
    """Fit a linear latency model on the rows of *raw_table* measured without prefix cache.

    Args:
        raw_table (pd.DataFrame): Raw measurements with columns 'cache_tokens',
            'token_load' and 'duration_s' (repeats allowed; they are averaged).

    Returns:
        tuple: (no_cache_series, slope, intercept). ``no_cache_series`` maps
        ``uncached_tokens`` (taken from 'token_load') to the mean 'duration_s' over rows
        with ``cache_tokens == 0``; slope/intercept come from a degree-1 ``np.polyfit``.

    Raises:
        ValueError: if fewer than two distinct 'token_load' values have ``cache_tokens == 0``,
            so no line can be fitted.
    """
    mean_durations = (
        raw_table.groupby(['cache_tokens', 'token_load'])['duration_s'].mean().reset_index()
    )
    no_cache = mean_durations[mean_durations['cache_tokens'] == 0]
    no_cache_series = (
        no_cache.rename(columns={'token_load': 'uncached_tokens'})
        .set_index('uncached_tokens')['duration_s']
    )
    if len(no_cache_series) < 2:
        raise ValueError(
            "Need at least two distinct token_load values with cache_tokens == 0 "
            "to fit the latency model."
        )
    x_values, y_values = _series_xy(no_cache_series)
    slope, intercept = np.polyfit(x_values, y_values, 1)
    return no_cache_series, slope, intercept


def _plot_latency_fit(no_cache_series, slope, intercept, path="latency_fit.png"):
    """Save a scatter of the measured latencies plus the fitted line to *path*."""
    x_values, y_values = _series_xy(no_cache_series)
    # Use a dedicated figure so we neither draw onto nor leave behind the caller's figure.
    fig, ax = plt.subplots()
    ax.scatter(x_values, y_values, label='Data')
    ax.plot(x_values, slope * x_values + intercept, label='Linear Fit', color='red')
    ax.set_xlabel('Uncached Tokens')
    ax.set_ylabel('Latency (s)')
    ax.legend()
    fig.savefig(path, dpi=300)
    plt.close(fig)


def uncached_tokens_to_latency(uncached_tokens):
    """Predict latency (seconds) from the number of uncached prompt tokens.

    On first use, ``./vllm_latency_experiments_repeats.csv`` is read and a linear model is
    fitted on the mean durations measured with ``cache_tokens == 0``. The resulting Series and
    coefficients are stored in the module globals ``table_cached``, ``slope_latency_model``
    and ``intercept_latency_model`` and reused by later calls; the fit's MSE, slope and
    intercept are printed once. If the module flag ``plot_fit`` is True, the data and fit are
    saved to ``latency_fit.png`` and the flag is reset to False.

    Args:
        uncached_tokens: Number of uncached tokens (scalar or numpy array).

    Returns:
        ``slope_latency_model * uncached_tokens + intercept_latency_model``.

    Raises:
        ValueError: if the CSV does not contain enough no-cache points to fit a line.
    """
    global table_cached, plot_fit, slope_latency_model, intercept_latency_model
    if table_cached is None:
        raw_table = pd.read_csv("./vllm_latency_experiments_repeats.csv")
        no_cache_series, slope, intercept = _fit_latency_model(raw_table)
        # Publish all three values together, only after the fit succeeded, so a failure
        # leaves the globals untouched and the next call retries cleanly.
        table_cached, slope_latency_model, intercept_latency_model = (
            no_cache_series, slope, intercept
        )
        x_values, y_values = _series_xy(no_cache_series)
        mse_error = np.mean((y_values - (slope * x_values + intercept)) ** 2)
        print(f"MSE error: {mse_error:.8f}")
        print(f"Slope: {slope_latency_model:.8f}, Intercept: {intercept_latency_model:.8f}")
    if plot_fit:
        # Rebuild the plot data from the cached Series: the fit may have happened in an
        # earlier call, whose local x/y arrays no longer exist.
        _plot_latency_fit(table_cached, slope_latency_model, intercept_latency_model)
        plot_fit = False
    return slope_latency_model * uncached_tokens + intercept_latency_model

def Belady_cache_policy(a_list, C, forced = 1):
    """
    Simulate Belady's (farthest-next-use) replacement policy on a token cache.

    For each arrival, the cached part of its history is reused and the rest is counted as
    uncached; then its full history + response is cached. When the total cached tokens
    exceed ``C``, tokens are evicted (possibly partially) from conversations in decreasing
    order of their next arrival time, so conversations with ``next_arrival_time == inf``
    go first.

    Args:
        a_list (pd.DataFrame): Arrivals in processing order with columns 'conv_idx',
            'history_length', 'response_length' and 'next_arrival_time'. 'conv_idx' is used
            directly as an array index, so it must hold non-negative integers.
        C (int): Cache capacity in tokens.
        forced (int, optional): If 1, the arriving conversation is never evicted in its own
            step, and the simulation stops early (printing a message) once an arrival's
            history alone exceeds ``C``. Defaults to 1.

    Returns:
        list: Number of uncached tokens for each processed arrival (shorter than
        ``a_list`` if the simulation stopped early; empty for an empty ``a_list``).
    """
    if len(a_list) == 0:
        return []
    # conv_idx indexes x_state directly, so size it by the largest id (not max - min).
    x_state = np.zeros(int(a_list['conv_idx'].max()) + 1)  # cached tokens per conversation
    num_uncached_tokens = []
    next_arrivals_dict = {}  # conv_idx -> time of that conversation's next arrival

    for _, arrival in tqdm(a_list.iterrows(), total=len(a_list)):
        if arrival['history_length'] > C and forced == 1:
            print('Capacity is too small to allow forced caching.')
            break
        current_arrival_id = int(arrival['conv_idx'])
        next_arrivals_dict[current_arrival_id] = arrival['next_arrival_time']
        num_uncached_tokens.append(arrival['history_length'] - x_state[current_arrival_id])

        # Forced caching: the arriving conversation's full context is now in the cache.
        x_state[current_arrival_id] = arrival['history_length'] + arrival['response_length']

        # Track the overflow incrementally instead of re-summing x_state on every step.
        excess = x_state.sum() - C
        if excess > 0:  # Belady's eviction rule
            # Farthest next arrival first; sorted() is stable, so ties keep dict order.
            conv_order = sorted(next_arrivals_dict, key=next_arrivals_dict.get, reverse=True)
            for p in conv_order:
                if excess <= 0:
                    break
                if forced == 1 and p == current_arrival_id:
                    continue
                evict_amount = min(x_state[p], excess)
                x_state[p] -= evict_amount
                excess -= evict_amount
    return num_uncached_tokens


def run_policies_and_plot_tail_metrics(a_list, C_values, xi=200, Q=50, forced=1, threshold=1024):
    """
    Run every caching policy for each cache capacity and plot tail percentiles of uncached tokens.

    For each ``C`` in ``C_values`` the following policies are simulated on ``a_list``:
    tail-optimized Belady, tail-optimized LRU with the 'None' / 'End' / 'Perfect' predictors,
    vanilla Belady, vanilla LRU and threshold LRU. The 90th, 95th and 99th percentiles of
    the per-arrival uncached-token counts are plotted (one subplot per percentile), saved
    to ``tail_metrics_comparison.png`` and shown.

    Args:
        a_list (pd.DataFrame): Conversation arrivals; see the individual policies for the
            required columns.
        C_values (list): Cache capacities to test.
        xi: Tail threshold used by the tail-optimized policies when computing each
            conversation's target cache (history + response + next prompt - xi).
        Q: Assumed next-prompt length for tail-optimized LRU with the 'None' and 'End'
            predictors. A smaller Q makes eviction more proactive (further from plain LRU).
        forced (int): Passed to every policy; 1 protects the arriving conversation from
            eviction in its own step.
        threshold (int): Threshold LRU only caches arrivals whose history + response
            exceeds this. Defaults to 1024.

    Returns:
        dict: policy key -> {percentile: [value for each C]}. Keys: 'belady', 'lru',
        'lru_end', 'lru_perfect', 'vanilla_belady', 'vanilla_lru', 'thre_lru'. A value is
        NaN when a policy processed no arrivals for that C.
    """
    percentiles = [90, 95, 99]

    # key -> (plot label, line style, runner(C) -> uncached tokens per arrival).
    # Insertion order fixes both the execution order and the legend order.
    policies = {
        'belady': ('Tail-Optimized Belady', 'o-',
                   lambda C: tail_optimized_Belady_cache_policy(a_list, C, xi, forced)),
        'lru': ('Tail-Optimized LRU (None)', 's-',
                lambda C: tail_optimized_LRU_cache_policy(a_list, C, xi, Q, predictor='None', forced=forced)),
        'lru_end': ('Tail-Optimized LRU (End)', 'D-',
                    lambda C: tail_optimized_LRU_cache_policy(a_list, C, xi, Q, predictor='End', forced=forced)),
        'lru_perfect': ('Tail-Optimized LRU (Perfect)', 'P-',
                        lambda C: tail_optimized_LRU_cache_policy(a_list, C, xi, Q, predictor='Perfect', forced=forced)),
        'vanilla_belady': ('Vanilla Belady', '^--',
                           lambda C: Belady_cache_policy(a_list, C, forced)),
        'vanilla_lru': ('Vanilla LRU', 'v--',
                        lambda C: LRU_cache_policy(a_list, C, forced)),
        'thre_lru': ('Threshold LRU', 'x--',
                     lambda C: thre_lru_cache_policy(a_list, C, threshold=threshold, forced=forced)),
    }
    results = {key: {p: [] for p in percentiles} for key in policies}

    for C in tqdm(C_values, desc="Testing cache capacities"):
        for key, (_, _, run_policy) in policies.items():
            uncached = run_policy(C)
            for p in percentiles:
                # np.percentile raises on an empty list (e.g. forced caching stopped at the
                # first arrival); record NaN so the sweep and the plot still complete.
                value = np.percentile(uncached, p) if len(uncached) > 0 else np.nan
                results[key][p].append(value)

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle("Tail Metrics of Uncached Tokens vs. Cache Capacity", fontsize=16)
    axes = axes.flatten()

    # One subplot per percentile, every policy drawn on each.
    for i, p in enumerate(percentiles):
        ax = axes[i]
        for key, (label, style, _) in policies.items():
            ax.plot(C_values, results[key][p], style, label=label)
        ax.set_xlabel('Cache Capacity (C)')
        ax.set_ylabel(f'{p}th Percentile of Uncached Tokens')
        ax.set_title(f'{p}th Percentile')
        ax.grid(True, alpha=0.3)
        if i == 0:
            ax.legend()

    # Three percentiles on a 2x2 grid: hide the unused panel(s) instead of drawing empty axes.
    for ax in axes[len(percentiles):]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.subplots_adjust(top=0.92)
    plt.savefig("tail_metrics_comparison.png", dpi=300)
    plt.show()

    return results





def thre_lru_cache_policy(a_list, C, threshold, forced=1):
    """
    Simulate a threshold LRU token cache.

    For each arrival, the cached part of its history is reused and the rest is counted as
    uncached. If history + response is at or below ``threshold``, the arrival is not
    cached; any tokens cached for that conversation on an earlier turn are left in place.
    Otherwise its full history + response is cached, and if the cache then exceeds ``C``,
    tokens are evicted (possibly partially) from the least recently arrived conversations
    first.

    Args:
        a_list (pd.DataFrame): Arrivals in processing order with columns 'conv_idx',
            'history_length', 'response_length' and 'timestamp'. 'conv_idx' is used directly
            as an array index, so it must hold non-negative integers.
        C (int): Cache capacity in tokens.
        threshold (int): Only arrivals whose history + response is strictly greater than
            this are cached.
        forced (int, optional): If 1, the arriving conversation is never evicted in its own
            step, and the simulation stops early (printing a message) once an arrival's
            history alone exceeds ``C``. Defaults to 1.

    Returns:
        list: Number of uncached tokens for each processed arrival (empty for an empty
        ``a_list``).
    """
    if len(a_list) == 0:
        return []
    # conv_idx indexes x_state directly, so size it by the largest id (not max - min).
    x_state = np.zeros(int(a_list['conv_idx'].max()) + 1)  # cached tokens per conversation
    num_uncached_tokens = []
    last_arrivals_dict = {}  # conv_idx -> timestamp of its most recent arrival

    for _, arrival in tqdm(a_list.iterrows(), total=len(a_list)):
        if arrival['history_length'] > C and forced == 1:
            print('Capacity is too small to allow forced caching.')
            break

        current_arrival_id = int(arrival['conv_idx'])
        last_arrivals_dict[current_arrival_id] = arrival['timestamp']
        num_uncached_tokens.append(arrival['history_length'] - x_state[current_arrival_id])

        total_to_cache = arrival['history_length'] + arrival['response_length']
        if total_to_cache <= threshold:
            continue  # too small to be worth caching

        x_state[current_arrival_id] = total_to_cache

        # Track the overflow incrementally instead of re-summing x_state on every step.
        excess = x_state.sum() - C
        if excess > 0:
            # Oldest timestamp first (LRU); sorted() is stable, so ties keep dict order.
            conv_order = sorted(last_arrivals_dict, key=last_arrivals_dict.get)
            for p in conv_order:
                if excess <= 0:
                    break
                if forced == 1 and p == current_arrival_id:
                    continue
                evict_amount = min(x_state[p], excess)
                x_state[p] -= evict_amount
                excess -= evict_amount

    return num_uncached_tokens


def tail_optimized_Belady_cache_policy(a_list, C, xi, forced = 1):
    """
    Simulate the tail-optimized Belady cache replacement policy (offline).

    For each arrival, the cached part of its history is reused and the rest is counted as
    uncached; then its full history + response is cached. Each conversation also gets a
    target cache, the amount worth keeping:

        target_cache = 0                                   if next_arrival_time == inf
        target_cache = max(0, history_length + response_length + next_prompt_length - xi)
                                                            otherwise

    capped at history + response. Tokens above the target are "free" to evict. When the
    cache exceeds ``C`` eviction runs in two phases:
      1. Evict for free: candidate conversations are shrunk toward their target, in the
         order they joined the candidate set; a conversation leaves the set once it is at
         or below its target.
      2. Belady: if still over capacity, tokens are evicted from conversations in
         decreasing order of next arrival time.

    Args:
        a_list (pd.DataFrame): Arrivals in processing order with columns 'conv_idx',
            'history_length', 'response_length', 'next_arrival_time' (inf on the last turn)
            and 'next_prompt_length'. 'conv_idx' is used directly as an array index, so it
            must hold non-negative integers.
        C (int): Cache capacity in tokens.
        xi (int): Tail threshold in the target formula above; a larger xi lowers the target
            and so enlarges the evict-for-free portion.
        forced (int, optional): If 1 (default), the arriving conversation is protected from
            eviction in both phases of its own step, and the simulation stops early (printing
            a message) once an arrival's history alone exceeds ``C``. If 0, it is cached to
            its full size but is immediately eligible for eviction like any other.

    Returns:
        list: Number of uncached tokens for each processed arrival (empty for an empty
        ``a_list``).
    """
    if len(a_list) == 0:
        return []
    # conv_idx indexes x_state directly, so size it by the largest id (not max - min).
    x_state = np.zeros(int(a_list['conv_idx'].max()) + 1)  # cached tokens per conversation
    num_uncached_tokens = []
    next_arrivals_dict = {}  # conv_idx -> time of that conversation's next arrival
    free_x_state = {}  # conv_idx -> target cache; tokens above it are free to evict
    # Conversations that may still hold free-to-evict tokens, in insertion order. A dict is
    # used as an ordered set: O(1) membership/removal, same iteration order as a list.
    free_candidates = {}

    for _, arrival in tqdm(a_list.iterrows(), total=len(a_list)):
        if arrival['history_length'] > C and forced == 1:
            print('Capacity is too small to allow forced caching.')
            break
        current_arrival_id = int(arrival['conv_idx'])
        next_arrivals_dict[current_arrival_id] = arrival['next_arrival_time']
        num_uncached_tokens.append(arrival['history_length'] - x_state[current_arrival_id])

        full_context = arrival['history_length'] + arrival['response_length']
        if arrival['next_arrival_time'] == float('inf'):  # last turn: nothing worth keeping
            target_cache = 0
        else:
            target_cache = max(0, full_context + arrival['next_prompt_length'] - xi)
        free_x_state[current_arrival_id] = min(full_context, target_cache)

        x_state[current_arrival_id] = full_context  # forced caching
        free_candidates.setdefault(current_arrival_id, None)

        # Track the overflow incrementally instead of re-summing x_state on every step.
        excess = x_state.sum() - C
        if excess > 0:  # phase 1: evict for free
            exhausted = []
            for p in free_candidates:
                if excess <= 0:
                    break
                if forced == 1 and p == current_arrival_id:
                    continue
                # Clamp at 0: a conversation already below its target must not get a
                # negative "eviction", which would re-add tokens that are no longer cached.
                evict_amount = max(0, min(x_state[p] - free_x_state[p], excess))
                x_state[p] -= evict_amount
                excess -= evict_amount
                if x_state[p] <= free_x_state[p]:
                    exhausted.append(p)
            for p in exhausted:
                del free_candidates[p]

        if excess > 0:  # phase 2: Belady's eviction rule
            # Farthest next arrival first; sorted() is stable, so ties keep dict order.
            conv_order = sorted(next_arrivals_dict, key=next_arrivals_dict.get, reverse=True)
            for p in conv_order:
                if excess <= 0:
                    break
                if forced == 1 and p == current_arrival_id:
                    continue
                evict_amount = min(x_state[p], excess)
                x_state[p] -= evict_amount
                excess -= evict_amount
    return num_uncached_tokens

# def Belady_cache_policy(a_list, C, forced = 1):
#     N = len(set(d['conv_idx'] for d in a_list))
#     x_state = np.zeros(N+1) # save cache state
#     num_uncached_tokens = []
#     next_arrivals_dict = {}
#     for arrival in tqdm(a_list):
#         if arrival['history_length'] > C and forced == 1:
#             print('Capacity is too small to allow forced caching.')
#             break
#         current_arrival_id = arrival['conv_idx']
#         x_arrival = x_state[current_arrival_id]
#         next_arrivals_dict[current_arrival_id] = arrival['next_arrival_time']
#         num_uncached_tokens.append(arrival['history_length'] - x_arrival)

#         # Set cache amount for arriving conversation
#         x_state[current_arrival_id] =  arrival['history_length'] + arrival['response_length'] # forced caching

#         if sum(x_state) > C: # Belady's eviction rule
#             # Sort by arrival time (descending)
#             next_arrivals = [(idx, time) for idx, time in next_arrivals_dict.items()]
#             # Sort by arrival time (descending)
#             conv_order = [idx for idx, _ in sorted(next_arrivals, key=lambda x: x[1], reverse=True)]
#             for p in conv_order:
#                 if forced == 1 and p == current_arrival_id:
#                     continue
#                 if sum(x_state) <= C:
#                     break
#                 evict_amount = min(x_state[p], sum(x_state) - C)
#                 x_state[p] -= evict_amount
#     return num_uncached_tokens

def tail_optimized_LRU_cache_policy(a_list, C, xi, Q, predictor = 'None', forced = 1):
    """
    Simulate the tail-optimized LRU cache policy (online counterpart of tail-optimized Belady).

    For each arrival, the cached part of its history is reused and the rest is counted as
    uncached; then its full history + response is cached. Each conversation also gets a
    target cache ``max(0, history + response + next_prompt - xi)``, capped at
    history + response; tokens above it are "free" to evict. The next prompt length
    depends on ``predictor``:
      - 'None': whether a next turn exists is unknown; ``Q`` is used on every turn.
      - 'End': last turns (next_arrival_time == inf) get target 0; other turns use ``Q``.
      - 'Perfect': last turns get target 0; other turns use the true 'next_prompt_length'.

    When the cache exceeds ``C`` eviction runs in two phases:
      1. Evict for free: candidate conversations are shrunk toward their target, in the
         order they joined the candidate set; a conversation leaves the set once it is at
         or below its target.
      2. LRU: if still over capacity, tokens are evicted from the least recently arrived
         conversations first.

    Args:
        a_list (pd.DataFrame): Arrivals in processing order with columns 'conv_idx',
            'history_length', 'response_length' and 'timestamp'; the 'End' and 'Perfect'
            predictors also read 'next_arrival_time', and 'Perfect' reads
            'next_prompt_length'. 'conv_idx' is used directly as an array index, so it must
            hold non-negative integers.
        C (int): Cache capacity in tokens.
        xi (int): Tail threshold in the target formula; larger xi lowers the target.
        Q (int): Assumed next prompt length for the 'None' and 'End' predictors. The smaller
            Q, the more proactive the eviction and the further from plain LRU.
        predictor (str, optional): 'None', 'End' or 'Perfect' (see above). Defaults to 'None'.
        forced (int, optional): If 1 (default), the arriving conversation is protected from
            eviction in its own step, and the simulation stops early (printing a message)
            once an arrival's history alone exceeds ``C``.

    Returns:
        list: Number of uncached tokens for each processed arrival (empty for an empty
        ``a_list``).

    Raises:
        ValueError: if ``predictor`` is not one of 'None', 'End', 'Perfect'.
    """
    if predictor not in ('None', 'End', 'Perfect'):
        raise ValueError(f"predictor must be 'None', 'End' or 'Perfect', got {predictor!r}")
    if len(a_list) == 0:
        return []
    # conv_idx indexes x_state directly, so size it by the largest id (not max - min).
    x_state = np.zeros(int(a_list['conv_idx'].max()) + 1)  # cached tokens per conversation
    num_uncached_tokens = []
    last_arrivals_dict = {}  # conv_idx -> timestamp of its most recent arrival
    free_x_state = {}  # conv_idx -> target cache; tokens above it are free to evict
    # Conversations that may still hold free-to-evict tokens, in insertion order. A dict is
    # used as an ordered set: O(1) membership/removal, same iteration order as a list.
    free_candidates = {}

    for _, arrival in tqdm(a_list.iterrows(), total=len(a_list)):
        if arrival['history_length'] > C and forced == 1:
            print('Capacity is too small to allow forced caching.')
            break
        current_arrival_id = int(arrival['conv_idx'])
        last_arrivals_dict[current_arrival_id] = arrival['timestamp']
        num_uncached_tokens.append(arrival['history_length'] - x_state[current_arrival_id])

        full_context = arrival['history_length'] + arrival['response_length']
        if predictor != 'None' and arrival['next_arrival_time'] == float('inf'):
            target_cache = 0  # the predictor knows this is the conversation's last turn
        elif predictor == 'Perfect':
            target_cache = max(0, full_context + arrival['next_prompt_length'] - xi)
        else:  # 'None' or 'End': estimate the next prompt length with Q
            target_cache = max(0, full_context + Q - xi)
        free_x_state[current_arrival_id] = min(full_context, target_cache)

        x_state[current_arrival_id] = full_context  # forced caching
        free_candidates.setdefault(current_arrival_id, None)

        # Track the overflow incrementally instead of re-summing x_state on every step.
        excess = x_state.sum() - C
        if excess > 0:  # phase 1: evict for free
            exhausted = []
            for p in free_candidates:
                if excess <= 0:
                    break
                if forced == 1 and p == current_arrival_id:
                    continue
                # Clamp at 0: a conversation already below its target must not get a
                # negative "eviction", which would re-add tokens that are no longer cached.
                evict_amount = max(0, min(x_state[p] - free_x_state[p], excess))
                x_state[p] -= evict_amount
                excess -= evict_amount
                if x_state[p] <= free_x_state[p]:
                    exhausted.append(p)
            for p in exhausted:
                del free_candidates[p]

        if excess > 0:  # phase 2: LRU's eviction rule
            # Oldest timestamp first; sorted() is stable, so ties keep dict order.
            conv_order = sorted(last_arrivals_dict, key=last_arrivals_dict.get)
            for p in conv_order:
                if excess <= 0:
                    break
                if forced == 1 and p == current_arrival_id:
                    continue
                evict_amount = min(x_state[p], excess)
                x_state[p] -= evict_amount
                excess -= evict_amount
    return num_uncached_tokens

def LRU_cache_policy(a_list, C, forced = 1):
    """
    Simulate a plain Least-Recently-Used token cache.

    For each arrival, the cached part of its history is reused and the rest is counted as
    uncached; then its full history + response is cached. When the total cached tokens
    exceed ``C``, tokens are evicted (possibly partially) from the least recently arrived
    conversations first.

    Args:
        a_list (pd.DataFrame): Arrivals in processing order with columns 'conv_idx',
            'history_length', 'response_length' and 'timestamp'. 'conv_idx' is used directly
            as an array index, so it must hold non-negative integers.
        C (int): Cache capacity in tokens.
        forced (int, optional): If 1, the arriving conversation is never evicted in its own
            step, and the simulation stops early (printing a message) once an arrival's
            history alone exceeds ``C``. Defaults to 1.

    Returns:
        list: Number of uncached tokens for each processed arrival (empty for an empty
        ``a_list``).
    """
    if len(a_list) == 0:
        return []
    # conv_idx indexes x_state directly, so size it by the largest id (not max - min).
    x_state = np.zeros(int(a_list['conv_idx'].max()) + 1)  # cached tokens per conversation
    num_uncached_tokens = []
    last_arrivals_dict = {}  # conv_idx -> timestamp of its most recent arrival

    for _, arrival in tqdm(a_list.iterrows(), total=len(a_list)):
        if arrival['history_length'] > C and forced == 1:
            print('Capacity is too small to allow forced caching.')
            break
        current_arrival_id = int(arrival['conv_idx'])
        last_arrivals_dict[current_arrival_id] = arrival['timestamp']
        num_uncached_tokens.append(arrival['history_length'] - x_state[current_arrival_id])

        # Forced caching: the arriving conversation's full context is now in the cache.
        x_state[current_arrival_id] = arrival['history_length'] + arrival['response_length']

        # Track the overflow incrementally instead of re-summing x_state on every step.
        excess = x_state.sum() - C
        if excess > 0:  # LRU's eviction rule
            # Oldest timestamp first; sorted() is stable, so ties keep dict order.
            conv_order = sorted(last_arrivals_dict, key=last_arrivals_dict.get)
            for p in conv_order:
                if excess <= 0:
                    break
                if forced == 1 and p == current_arrival_id:
                    continue
                evict_amount = min(x_state[p], excess)
                x_state[p] -= evict_amount
                excess -= evict_amount
    return num_uncached_tokens

def process_arrival(arrival, x_state, free_x_state, free_x_state_list, next_arrivals_dict, C, xi, forced):
    """
    Apply one arrival to tail-optimized Belady state, mutating the containers in place.

    Records the arrival's next arrival time, computes how many tokens are worth keeping
    for its conversation (0 on the last turn, otherwise
    ``max(0, history + response + next_prompt - xi)``, capped at history + response),
    caches the full history + response, and appends the conversation to
    ``free_x_state_list`` if it is not already there. ``C`` and ``forced`` are accepted
    but not used by this function.

    Returns:
        The number of history tokens that were not in the cache when the arrival came in.
    """
    conv = int(arrival['conv_idx'])
    history = arrival['history_length']
    context = history + arrival['response_length']
    missing = history - x_state[conv]
    next_arrivals_dict[conv] = arrival['next_arrival_time']

    is_last_turn = arrival['next_arrival_time'] == float('inf')
    keep = 0 if is_last_turn else max(0, context + arrival['next_prompt_length'] - xi)
    free_x_state[conv] = min(context, keep)

    x_state[conv] = context  # forced caching
    if conv not in free_x_state_list:
        free_x_state_list.append(conv)
    return missing

def evict_for_free(x_state, free_x_state, free_x_state_list, C, forced, current_arrival_id):
    """
    Phase 1 of the tail-optimized policies: evict "free" tokens until the cache fits in ``C``.

    Walks ``free_x_state_list`` in order and shrinks each conversation toward its target
    ``free_x_state[p]`` (never below it), stopping once the total of ``x_state`` is at most
    ``C``. Conversations that end at or below their target are removed from
    ``free_x_state_list``. ``x_state`` and ``free_x_state_list`` are modified in place.

    Args:
        x_state: Cached tokens per conversation, indexed by conv_idx.
        free_x_state (dict): conv_idx -> target amount to keep.
        free_x_state_list (list): Candidate conv_idx values, in eviction order.
        C: Cache capacity in tokens.
        forced (int): If 1, ``current_arrival_id`` is skipped.
        current_arrival_id (int): The conversation that just arrived.
    """
    # Track the overflow incrementally instead of re-summing x_state on every step.
    excess = sum(x_state) - C
    exhausted = []
    for p in free_x_state_list:
        if excess <= 0:
            break
        if forced == 1 and p == current_arrival_id:
            continue
        # Clamp at 0: a conversation already below its target must not get a negative
        # "eviction", which would re-add tokens that are no longer cached.
        evict_amount = max(0, min(x_state[p] - free_x_state[p], excess))
        x_state[p] -= evict_amount
        excess -= evict_amount
        # <= rather than ==: also drops conversations that were already below target.
        if x_state[p] <= free_x_state[p]:
            exhausted.append(p)
    for p in exhausted:
        if p in free_x_state_list:
            free_x_state_list.remove(p)

def belady_eviction(x_state, next_arrivals_dict, C, forced, current_arrival_id):
    """
    Evict tokens in place until the cache fits in ``C``, farthest next arrival first (Belady).

    Args:
        x_state: Cached tokens per conversation, indexed by conv_idx; modified in place.
        next_arrivals_dict (dict): conv_idx -> next arrival time. Conversations are drained,
            possibly partially, in decreasing order of this time; ties keep dict order.
        C: Cache capacity in tokens.
        forced (int): If 1, ``current_arrival_id`` is never evicted, so the total may stay
            above ``C`` when only that conversation is left.
        current_arrival_id (int): The conversation that just arrived.
    """
    # Track the overflow incrementally instead of re-summing x_state twice per step.
    excess = sum(x_state) - C
    if excess <= 0:
        return
    for p in sorted(next_arrivals_dict, key=next_arrivals_dict.get, reverse=True):
        if excess <= 0:
            break
        if forced == 1 and p == current_arrival_id:
            continue
        evict_amount = min(x_state[p], excess)
        x_state[p] -= evict_amount
        excess -= evict_amount



def load_data(name, data_dir='./saved_data'):
    """
    Load a saved arrival list from ``<data_dir>/arrival_list_<name>.json``.

    Args:
        name (str): Dataset tag, e.g. "ShareGPT_easy_lambda_conv_1_lambda_turn_5".
        data_dir (str, optional): Directory holding the files. Defaults to './saved_data'.

    Returns:
        The parsed JSON content.
    """
    path = f"{data_dir}/arrival_list_{name}.json"
    # JSON text is UTF-8; don't depend on the platform's locale default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


if __name__ == "__main__":
    import os
    # NOTE(review): nothing in this file uses a GPU; kept in case an imported library
    # reads this variable — confirm it is still needed.
    os.environ["CUDA_VISIBLE_DEVICES"] = "6,7"

    C_values = [1000, 2000, 4000]  # cache capacities (tokens) to sweep
    xi = 500    # tail threshold for the tail-optimized policies
    Q = 200     # assumed next-prompt length for tail-optimized LRU
    forced = 0  # 0: the arriving conversation may be evicted in its own step

    # Alternative dataset: load_data("WildChat")
    a_list = pd.DataFrame(load_data("ShareGPT_easy_lambda_conv_1_lambda_turn_5"))
    a_list['conv_idx'] = a_list['conv_idx'].astype(int)
    # Sub-sample: keep only conversations with conv_idx 1-50.
    a_list = a_list[a_list['conv_idx'].isin(range(1, 51))]

    # Pass the parameters declared above (previously duplicated as literals, so edits
    # to these variables had no effect).
    run_policies_and_plot_tail_metrics(a_list, C_values, xi=xi, Q=Q, forced=forced)
    
    
    

